{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import mindspore\n",
    "from mindspore import ops\n",
    "from mindnlp.transformers import AutoModelForSequenceClassification\n",
    "from mindnlp.transformers.models.bert.modeling_bert import BertDualForSequenceClassification\n",
    "from collections import OrderedDict\n",
    "from mindnlp._legacy.hypercomplex.tensor_decomposition.hypercomplex_td import LinearTDLayer, decompose_linear_parameters, set_new_dict_names, calculate_parameters\n",
    "import re"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the pretrained 'bert-base-cased' sequence classifier; its config is reused below.\n",
    "real_model = AutoModelForSequenceClassification.from_pretrained(\n",
    "    'bert-base-cased',\n",
    "    num_labels=2,\n",
    ")\n",
    "\n",
    "# Build the dual model from the same configuration.\n",
    "model = BertDualForSequenceClassification(real_model.config)\n",
    "\n",
    "# Report the parameter count before any compression is applied.\n",
    "param_count = calculate_parameters(model)\n",
    "print('Parameters of model:', param_count)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sub-layer name fragments selected for compression, and the threshold handed to\n",
    "# decompose_linear_parameters.\n",
    "layers_list = [\"query\", \"key\", \"value\", \"dense\"]\n",
    "threshold = 0.2\n",
    "\n",
    "# Name of the notebook variable that holds the dual model.\n",
    "model_name = \"model\"\n",
    "\n",
    "# Snapshot of the model's current parameters keyed by name, plus the (initially empty)\n",
    "# dictionary that collects the parameters to be written to the checkpoint.\n",
    "params_dict = OrderedDict(model.parameters_and_names())\n",
    "new_dict = OrderedDict()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Swap each selected linear layer (see layers_list) inside the numbered hidden layers for a\n",
    "# LinearTDLayer and gather the matching decomposed parameters in new_dict.\n",
    "# Parameters outside those layers are copied to new_dict unchanged.\n",
    "layer_markers = [\".\" + str(idx) + \".\" for idx in range(real_model.config.num_hidden_layers)]\n",
    "\n",
    "for name, p in params_dict.items():\n",
    "    needed_layer = any(marker in name for marker in layer_markers)\n",
    "    needed_qkv = any(fragment in name for fragment in layers_list)\n",
    "    if not needed_layer or not needed_qkv:\n",
    "        new_dict[name] = p\n",
    "        continue\n",
    "    print (name, p.shape)\n",
    "\n",
    "    # Only the \"weight_x\" entry drives the compression. Its companions (weight_y, bias_x,\n",
    "    # bias_y) are looked up from it below, so they are skipped here rather than copied.\n",
    "    if not name.endswith(\"weight_x\"):\n",
    "        continue\n",
    "\n",
    "    # Slice the suffix off once instead of str.replace, which would also rewrite any\n",
    "    # earlier occurrence of the same text inside the parameter name.\n",
    "    prefix = name[:-len(\"weight_x\")]\n",
    "    module_path = prefix.rstrip(\".\")\n",
    "    print('Considered layer:',name)\n",
    "    print(\"\\tCompressing of\", module_path, \"...\")\n",
    "\n",
    "    # Stack the transposed x and y weights along a new trailing axis and decompose them.\n",
    "    w_y = params_dict[prefix + \"weight_y\"]\n",
    "    param = ops.cat([ops.unsqueeze(p.T, -1), ops.unsqueeze(w_y.T, -1)], -1).asnumpy()\n",
    "\n",
    "    p1, p2 = decompose_linear_parameters(param, threshold)\n",
    "    rk = p1._width\n",
    "\n",
    "    # Require both bias halves: indexing params_dict when only one of them is present\n",
    "    # would raise KeyError.\n",
    "    bx_name = prefix + \"bias_x\"\n",
    "    by_name = prefix + \"bias_y\"\n",
    "    bias_flag = bx_name in params_dict and by_name in params_dict\n",
    "    b_x = params_dict[bx_name] if bias_flag else None\n",
    "    b_y = params_dict[by_name] if bias_flag else None\n",
    "\n",
    "    set_new_dict_names(p1, p2, name, new_dict, b_x, b_y)\n",
    "\n",
    "    # Walk from the model down to the parent of the layer being replaced. Numeric path\n",
    "    # components index into a cell container; everything else is a plain attribute.\n",
    "    # This avoids building a source string and passing it to eval().\n",
    "    *parent_path, op_name = module_path.split(\".\")\n",
    "    parent = model\n",
    "    for part in parent_path:\n",
    "        parent = parent[int(part)] if part.isdigit() else getattr(parent, part)\n",
    "\n",
    "    op = LinearTDLayer(p.shape[1], p.shape[0], bias_flag, rk)\n",
    "    setattr(parent, op_name, op)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the collected parameters, then reload them into the modified model.\n",
    "ckpt_path = './dual_bert.ckpt'\n",
    "mindspore.save_checkpoint(new_dict, ckpt_path)\n",
    "param_dict = mindspore.load_checkpoint(ckpt_path)\n",
    "\n",
    "# load_param_into_net returns information about parameters it could not match.\n",
    "# Print it so a naming mismatch between the checkpoint and the replaced layers is\n",
    "# visible instead of silently leaving those layers at their initial weights.\n",
    "not_loaded = mindspore.load_param_into_net(model, param_dict)\n",
    "print('Parameters not loaded:', not_loaded)\n",
    "print('Parameters of obtained model:', calculate_parameters(model))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
